[TRAINIUM] improve support #421

Draft

mayank31398 wants to merge 29 commits into main from n

Conversation

@mayank31398
Collaborator

No description provided.

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
This reverts commit 0b7b3df.
This reverts commit 0afe9ab.
@gemini-code-assist Bot (Contributor) left a comment


Code Review

This pull request introduces comprehensive support for AWS Trainium accelerators, including specialized data type handling (int32), dynamic selection of compilation backends, and integration with the Neuron profiler. It also implements efficient model initialization and refactors positional embedding logic within the dense model mixins. Feedback focuses on correcting the usage of nullcontext, addressing potential RuntimeError and TypeError exceptions in the autocast and profiling logic due to platform-specific arguments, and improving the portability of utility scripts by removing hardcoded file paths.

Comment thread lm_engine/pretrain.py
device_type=Accelerator.get_device_type(), dtype=string_to_torch_dtype(args.mixed_precision_args.dtype)
)
if args.distributed_args.fsdp_algorithm is None
else nullcontext

critical

nullcontext is a class and must be instantiated (i.e., nullcontext()) to be used as a context manager. Using the bare class in a with statement raises a TypeError because the __enter__/__exit__ special methods are looked up on the type of the context expression, and the class object itself does not satisfy the context manager protocol.

Suggested change
else nullcontext
else nullcontext()
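
As a minimal standalone sketch (not specific to this repository), the difference looks like this:

from contextlib import nullcontext

# Instantiated: a valid no-op context manager.
with nullcontext():
    pass

# Bare class: fails at `with` time (TypeError on Python 3.11+,
# AttributeError on older versions), because the special-method
# lookup happens on the type of the context expression.
# with nullcontext:
#     pass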

Comment thread lm_engine/pretrain.py
enable_kernels(args.kernel_args.kernels),
(
torch.autocast(
device_type=Accelerator.get_device_type(), dtype=string_to_torch_dtype(args.mixed_precision_args.dtype)

high

torch.autocast does not natively support device_type='xla' or 'neuron' in standard PyTorch. This will cause a RuntimeError when running on TPU or Trainium if the environment does not have a specifically patched version of PyTorch. For these accelerators, it is generally recommended to use device_type='cpu' (which is how Neuron AMP is typically triggered) or the accelerator-specific autocast context (e.g., torch_xla.amp.autocast).
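
A hedged sketch of one way to select the context per device: Accelerator.get_device_type() is the helper already used in this PR, while the "cuda"/"cpu" whitelist and the bf16 CPU fallback for XLA-backed devices are assumptions about the runtime, not verified behavior.

from contextlib import nullcontext
import torch

def autocast_context(device_type: str, dtype: torch.dtype):
    # Standard PyTorch autocast backends.
    if device_type in ("cuda", "cpu"):
        return torch.autocast(device_type=device_type, dtype=dtype)
    # Assumption: on XLA/Neuron devices, AMP is typically driven through
    # CPU autocast (bf16) or handled by the compiler; fall back to a
    # no-op context rather than passing an unsupported device_type.
    if dtype == torch.bfloat16:
        return torch.autocast(device_type="cpu", dtype=dtype)
    return nullcontext()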

repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(path),
experimental_config=experimental_config,

high

The experimental_config parameter is not part of the torch.profiler.profile signature on older PyTorch releases. torch-neuronx supports it, but passing it (even as None) can raise a TypeError on installations that lack it (e.g., when running on CUDA or CPU). To maintain cross-platform compatibility, consider a conditional approach or dictionary unpacking so that profile is called without this argument on non-Trainium devices.
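
A sketch of the dictionary-unpacking approach; is_trainium, experimental_config, and path are stand-ins for whatever the surrounding code provides:

import torch

is_trainium = False         # hypothetical runtime check for torch-neuronx
experimental_config = None  # would be a Neuron-specific config object
path = "./profiler_traces"

profile_kwargs = {}
if is_trainium:
    # Only Neuron-aware builds ever see this keyword argument.
    profile_kwargs["experimental_config"] = experimental_config

prof = torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(path),
    **profile_kwargs,
)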

@@ -0,0 +1 @@
ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 trainium-melbourne -fN

medium

The script contains a hardcoded absolute path to a personal PEM file (~/Desktop/mayank-melbourne.pem). This makes the script non-portable and potentially exposes details about your local file system. It is recommended to use an environment variable or a generic placeholder.

Suggested change
ssh -i ~/Desktop/mayank-melbourne.pem -L 8001:localhost:3001 -L 8002:localhost:3002 trainium-melbourne -fN
ssh -i ${PEM_FILE:-/path/to/your/key.pem} -L 8001:localhost:3001 -L 8002:localhost:3002 ${REMOTE_HOST:-trainium-melbourne} -fN
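
With this form, PEM_FILE and REMOTE_HOST can be exported in the environment before running the script, and the ${VAR:-default} expansions fall back to the placeholders when they are unset.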
